// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { InferenceSession, Tensor } from 'onnxruntime-common';

import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
  dataLocationStringToEnum,
  tensorDataTypeEnumToString,
  tensorDataTypeStringToEnum,
  tensorTypeToTypedArrayConstructor,
} from './wasm-common';
import { prepareInputOutputTensor } from './wasm-core-impl';
import { getInstance } from './wasm-factory';
import { checkLastError } from './wasm-utils';

const NO_TRAIN_FUNCS_MSG =
  "Built without training API's enabled. Use the onnxruntime-web/training import for training " +
  'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
  'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';

/**
 * Runs the checkLastError function which will throw an error, if the provided error code matches the specified
 * pattern for an error code.
 * @param errCode number to evaluated for if it's an error
 * @param message message to pass into checkLastError
 * @param checkNeqZero when true, treats not equal to zero as an error.
 *                     When false, treats equal to zero as an error.
 */
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
  if (checkNeqZero && errCode !== 0) {
    checkLastError(message);
  } else if (!checkNeqZero && errCode === 0) {
    checkLastError(message);
  }
};

export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
  const wasm = getInstance();

  const [checkpointDataOffset, checkpointDataLength] = checkpointData;
  let checkpointHandle = 0;

  try {
    if (wasm._OrtTrainingLoadCheckpoint) {
      checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
    return checkpointHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
      wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
    }
    throw e;
  } finally {
    // free buffer from wasm heap
    wasm._OrtFree(checkpointData[0]);
  }
};

const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  try {
    const dataOffset = wasm.stackAlloc(8);
    if (wasm._OrtTrainingGetModelInputOutputCount) {
      const errorCode = wasm._OrtTrainingGetModelInputOutputCount(
        trainingSessionId,
        dataOffset,
        dataOffset + 4,
        isEvalModel,
      );
      ifErrCodeCheckLastError(errorCode, "Can't get session input/output count.");
      return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

const getModelInputOutputNamesLoop = (
  trainingSessionId: number,
  count: number,
  isInput: boolean,
  isEvalModel: boolean,
): string[] => {
  const names = [];
  const wasm = getInstance();

  for (let i = 0; i < count; i++) {
    if (wasm._OrtTrainingGetModelInputOutputName) {
      const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
      ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);

      names.push(wasm.UTF8ToString(name));
      wasm._free(name);
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  }
  return names;
};

export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
  let inputNames: string[] = [];
  let outputNames: string[] = [];

  const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);

  inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
  outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);

  return [inputNames, outputNames];
};

export const createTrainingSessionHandle = (
  checkpointHandle: number,
  trainModelData: SerializableInternalBuffer,
  evalModelData: SerializableInternalBuffer,
  optimizerModelData: SerializableInternalBuffer,
  options: InferenceSession.SessionOptions,
): number => {
  const wasm = getInstance();

  let trainingSessionHandle = 0;
  let sessionOptionsHandle = 0;
  let allocs: number[] = [];

  try {
    [sessionOptionsHandle, allocs] = setSessionOptions(options);
    if (wasm._OrtTrainingCreateSession) {
      trainingSessionHandle = wasm._OrtTrainingCreateSession(
        sessionOptionsHandle,
        checkpointHandle,
        trainModelData[0],
        trainModelData[1],
        evalModelData[0],
        evalModelData[1],
        optimizerModelData[0],
        optimizerModelData[1],
      );
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
    return trainingSessionHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
      wasm._OrtTrainingReleaseSession(trainingSessionHandle);
    }
    throw e;
  } finally {
    wasm._free(trainModelData[0]);
    wasm._free(evalModelData[0]);
    wasm._free(optimizerModelData[0]);

    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));
  }
};

/**
 * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
 * WASM tensors.
 *
 * @param trainingSessionId
 * @param indices for each tensor, the index of the input or output name that the tensor corresponds with
 * @param tensors list of TensorMetaData
 * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
 *                      handles of the allocated tensors on the heap
 * @param inputOutputAllocs modified in-place by this method
 * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
 */
const createAndAllocateTensors = (
  trainingSessionId: number,
  indices: number[],
  tensors: Array<TensorMetadata | null>,
  tensorHandles: number[],
  inputOutputAllocs: number[],
  indexAdd: number,
) => {
  const count = indices.length;

  // creates the tensors
  for (let i = 0; i < count; i++) {
    prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
  }

  // moves to heap
  const wasm = getInstance();
  const valuesOffset = wasm.stackAlloc(count * 4);
  let valuesIndex = valuesOffset / 4;
  for (let i = 0; i < count; i++) {
    wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
  }

  return valuesOffset;
};

/**
 * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information
 * associated with the tensor handle.
 *
 * @param outputValuesOffset
 * @param outputCount
 * @returns list of TensorMetadata retrieved from the output handles.
 */
const moveOutputToTensorMetadataArr = (
  outputValuesOffset: number,
  outputCount: number,
  outputTensorHandles: number[],
  outputTensors: Array<TensorMetadata | null>,
) => {
  const wasm = getInstance();
  const output: TensorMetadata[] = [];

  for (let i = 0; i < outputCount; i++) {
    const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
    if (tensor === outputTensorHandles[i]) {
      // output tensor is pre-allocated. no need to copy data.
      output.push(outputTensors[i]!);
      continue;
    }

    const beforeGetTensorDataStack = wasm.stackSave();
    // stack allocate 4 pointer value
    const tensorDataOffset = wasm.stackAlloc(4 * 4);

    let type: Tensor.Type | undefined,
      dataOffset = 0;
    try {
      const errorCode = wasm._OrtGetTensorData(
        tensor,
        tensorDataOffset,
        tensorDataOffset + 4,
        tensorDataOffset + 8,
        tensorDataOffset + 12,
      );
      ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

      let tensorDataIndex = tensorDataOffset / 4;
      const dataType = wasm.HEAPU32[tensorDataIndex++];
      dataOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsLength = wasm.HEAPU32[tensorDataIndex++];
      const dims = [];
      for (let i = 0; i < dimsLength; i++) {
        dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
      }
      wasm._OrtFree(dimsOffset);

      const size = dims.reduce((a, b) => a * b, 1);
      type = tensorDataTypeEnumToString(dataType);

      if (type === 'string') {
        const stringData: string[] = [];
        let dataIndex = dataOffset / 4;
        for (let i = 0; i < size; i++) {
          const offset = wasm.HEAPU32[dataIndex++];
          const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
          stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
        }
        output.push([type, dims, stringData, 'cpu']);
      } else {
        const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
        const data = new typedArrayConstructor(size);
        new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
          wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
        );
        output.push([type, dims, data, 'cpu']);
      }
    } finally {
      wasm.stackRestore(beforeGetTensorDataStack);
      if (type === 'string' && dataOffset) {
        wasm._free(dataOffset);
      }
      wasm._OrtReleaseTensor(tensor);
    }
  }

  return output;
};

export const lazyResetGrad = async (trainingSessionId: number): Promise<void> => {
  const wasm = getInstance();

  if (wasm._OrtTrainingLazyResetGrad) {
    const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
    ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad.");
  } else {
    throw new Error(NO_TRAIN_FUNCS_MSG);
  }
};

export const runTrainStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingRunTrainStep) {
      const errorCode = wasm._OrtTrainingRunTrainStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );
      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

export const runOptimizerStep = async (
  trainingSessionId: number,
  options: InferenceSession.RunOptions,
): Promise<void> => {
  const wasm = getInstance();

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    if (wasm._OrtTrainingOptimizerStep) {
      const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
      ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

export const runEvalStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingEvalStep) {
      const errorCode = wasm._OrtTrainingEvalStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );

      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  try {
    const sizeOffset = wasm.stackAlloc(4);
    if (wasm._OrtTrainingGetParametersSize) {
      const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
      ifErrCodeCheckLastError(errorCode, "Can't get parameters size");

      return wasm.HEAP32[sizeOffset / 4];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

export const getContiguousParameters = async (
  trainingSessionId: number,
  trainableOnly: boolean,
): Promise<TensorMetadata> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
  let tensor = 0;

  // allocates a buffer of the correct size on the WASM heap
  const paramsByteLength = 4 * parametersSize;
  const paramsOffset = wasm._malloc(paramsByteLength);

  // handles the dimensions-related createTensor parameters
  const dims = [parametersSize];

  const dimsOffset = wasm.stackAlloc(4);
  const dimsIndex = dimsOffset / 4;
  wasm.HEAP32[dimsIndex] = parametersSize;

  try {
    // wraps allocated array in a tensor
    tensor = wasm._OrtCreateTensor(
      tensorDataTypeStringToEnum(tensorTypeAsString),
      paramsOffset,
      paramsByteLength,
      dimsOffset,
      dims.length,
      dataLocationStringToEnum(locationAsString),
    );
    ifErrCodeCheckLastError(
      tensor,
      `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`,
      false,
    );

    if (wasm._OrtTrainingCopyParametersToBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
      ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters.");
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
    const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
    const data = new typedArrayConstructor(parametersSize);
    const output: TensorMetadata[] = [];
    new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
      wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength),
    );
    output.push([tensorTypeAsString, dims, data, locationAsString]);
    if (output.length !== 1) {
      throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
     one, got ${output.length}`);
    } else {
      return output[0];
    }
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    wasm._free(paramsOffset);
    wasm._free(dimsOffset);
    wasm.stackRestore(stack);
  }
};

export const loadParametersBuffer = async (
  trainingSessionId: number,
  buffer: Uint8Array,
  trainableOnly: boolean,
): Promise<void> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  // allocates & copies JavaScript buffer to WASM heap
  const bufferByteLength = buffer.length;
  const bufferCount = bufferByteLength / 4;
  const bufferOffset = wasm._malloc(bufferByteLength);
  wasm.HEAPU8.set(buffer, bufferOffset);

  // allocates and handles moving dimensions information to WASM memory
  const dimsOffset = wasm.stackAlloc(4);
  wasm.HEAP32[dimsOffset / 4] = bufferCount;
  const dimsLength = 1;
  let tensor = 0;

  try {
    tensor = wasm._OrtCreateTensor(
      tensorDataTypeStringToEnum(tensorTypeAsString),
      bufferOffset,
      bufferByteLength,
      dimsOffset,
      dimsLength,
      dataLocationStringToEnum(locationAsString),
    );
    ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);

    if (wasm._OrtTrainingCopyParametersFromBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
      ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters.");
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    wasm.stackRestore(stack);
    wasm._free(bufferOffset);
    wasm._free(dimsOffset);
  }
};

export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
  const wasm = getInstance();

  if (wasm._OrtTrainingReleaseSession) {
    wasm._OrtTrainingReleaseSession(sessionId);
  }
  if (wasm._OrtTrainingReleaseCheckpoint) {
    wasm._OrtTrainingReleaseCheckpoint(checkpointId);
  }
};
